{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:56.243473900Z",
     "start_time": "2023-05-04T20:36:56.121515500Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import Tensor\n",
    "from pytorch_toolbelt.modules import encoders, decoders, heads\n",
    "from pytorch_toolbelt.modules import ACT_RELU, ACT_SILU\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Creating Encoder"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "FeatureMapsSpecification(channels=(64, 64, 128, 256, 512), strides=(2, 4, 8, 16, 32))"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_toolbelt.modules import encoders\n",
    "\n",
    "encoder = encoders.Resnet34Encoder(pretrained=True, layers=[0, 1, 2, 3, 4])\n",
    "output_spec = encoder.get_output_spec()\n",
    "output_spec"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:56.477345100Z",
     "start_time": "2023-05-04T20:36:56.135517200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "data": {
      "text/plain": "[{'size': (1, 64, 128, 128),\n  'mean': 0.2657899856567383,\n  'std': 0.30636727809906006,\n  'dtype': torch.float32},\n {'size': (1, 64, 64, 64),\n  'mean': 0.7731860876083374,\n  'std': 0.6579753756523132,\n  'dtype': torch.float32},\n {'size': (1, 128, 32, 32),\n  'mean': 0.2784508764743805,\n  'std': 0.33034390211105347,\n  'dtype': torch.float32},\n {'size': (1, 256, 16, 16),\n  'mean': 0.10355071723461151,\n  'std': 0.21384787559509277,\n  'dtype': torch.float32},\n {'size': (1, 512, 8, 8),\n  'mean': 0.9371781349182129,\n  'std': 1.2523103952407837,\n  'dtype': torch.float32}]"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pytorch_toolbelt.utils import describe_outputs\n",
    "\n",
    "outputs = encoder(torch.randn(1, 3, 256, 256))\n",
    "describe_outputs(outputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:56.604150900Z",
     "start_time": "2023-05-04T20:36:56.476344400Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Changing number of input channels"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "data": {
      "text/plain": "[{'size': (1, 64, 128, 128),\n  'mean': 0.26584798097610474,\n  'std': 0.3063778281211853,\n  'dtype': torch.float32},\n {'size': (1, 64, 64, 64),\n  'mean': 0.7738164067268372,\n  'std': 0.6547278165817261,\n  'dtype': torch.float32},\n {'size': (1, 128, 32, 32),\n  'mean': 0.2798773944377899,\n  'std': 0.3327867388725281,\n  'dtype': torch.float32},\n {'size': (1, 256, 16, 16),\n  'mean': 0.10497166216373444,\n  'std': 0.2116292417049408,\n  'dtype': torch.float32},\n {'size': (1, 512, 8, 8),\n  'mean': 0.9424188137054443,\n  'std': 1.2670280933380127,\n  'dtype': torch.float32}]"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder = encoder.change_input_channels(1)\n",
    "outputs = encoder(torch.randn(1, 1, 256, 256))\n",
    "describe_outputs(outputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:56.755873800Z",
     "start_time": "2023-05-04T20:36:56.606150400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "data": {
      "text/plain": "FeatureMapsSpecification(channels=(36, 72, 144), strides=(8, 16, 32))"
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder = encoders.HRNetW18Encoder(pretrained=True, layers=[2, 3, 4], use_incre_features=False)\n",
    "encoder.get_output_spec()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:57.406449200Z",
     "start_time": "2023-05-04T20:36:56.739779800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:57.406449200Z",
     "start_time": "2023-05-04T20:36:57.406449200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:57.407452100Z",
     "start_time": "2023-05-04T20:36:57.406449200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "class GenericSegmentationModel(nn.Module):\n",
    "    def __init__(self, encoder, decoder, head):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.head = head\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.encoder(x)\n",
    "        features = self.decoder(features)\n",
    "        outputs = self.head(features, output_size=x.shape[-2:])\n",
    "        return outputs"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:57.407452100Z",
     "start_time": "2023-05-04T20:36:57.406449200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "from pytorch_toolbelt.modules.heads import ResizeHead\n",
    "\n",
    "\n",
    "def b4_bifpn():\n",
    "    encoder = encoders.TimmB4Encoder(\n",
    "        pretrained=True, layers=[1, 2, 3, 4], drop_path_rate=0.2, activation=ACT_SILU\n",
    "    )\n",
    "    decoder = decoders.BiFPNDecoder(\n",
    "        input_spec=encoder.get_output_spec(),\n",
    "        out_channels=256,\n",
    "        num_layers=3,\n",
    "        activation=ACT_SILU\n",
    "    )\n",
    "    head = ResizeHead(\n",
    "        input_spec=decoder.get_output_spec(),\n",
    "        num_classes=1,\n",
    "        dropout_rate=0.2,\n",
    "    )\n",
    "    return GenericSegmentationModel(\n",
    "        encoder,\n",
    "        decoder,\n",
    "        head,\n",
    "    )\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:57.407452100Z",
     "start_time": "2023-05-04T20:36:57.406449200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'total': '27.5M', 'trainable': '27.5M', 'encoder': '16.7M', 'decoder': '10.8M', 'head': '2.31K'}\n",
      "{'size': (1, 1, 256, 256), 'mean': -0.0968250185251236, 'std': 0.28567764163017273, 'dtype': torch.float32}\n"
     ]
    }
   ],
   "source": [
    "from pytorch_toolbelt.utils import count_parameters\n",
    "\n",
    "model = b4_bifpn()\n",
    "output = model(torch.randn(1, 3, 256, 256))\n",
    "\n",
    "print(count_parameters(model, human_friendly=True))\n",
    "print(describe_outputs(output))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:38:00.258790800Z",
     "start_time": "2023-05-04T20:37:59.381824100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "data": {
      "text/plain": "{'size': (1, 1, 256, 256),\n 'mean': -0.3548843264579773,\n 'std': 0.3300875425338745,\n 'dtype': torch.float32}"
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "describe_outputs(output)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:58.308492200Z",
     "start_time": "2023-05-04T20:36:57.856172100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "\n",
    "def hrnet_fpn():\n",
    "    encoder = encoders.HRNetW18Encoder(\n",
    "        pretrained=True, layers=[1, 2, 3, 4], use_incre_features=False,\n",
    "    )\n",
    "    decoder = decoders.FPNDecoder(\n",
    "        input_spec=encoder.get_output_spec(),\n",
    "        out_channels=256,\n",
    "    )\n",
    "    head = ResizeHead(\n",
    "        input_spec=decoder.get_output_spec(),\n",
    "        num_classes=1,\n",
    "        dropout_rate=0.2,\n",
    "    )\n",
    "    return GenericSegmentationModel(\n",
    "        encoder,\n",
    "        decoder,\n",
    "        head,\n",
    "    )\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:58.308492200Z",
     "start_time": "2023-05-04T20:36:58.307491500Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'total': '11.4M', 'trainable': '11.4M', 'encoder': '9.56M', 'decoder': '1.84M', 'head': '2.31K'}\n",
      "{'size': (1, 1, 256, 256), 'mean': 0.34645456075668335, 'std': 0.1458025723695755, 'dtype': torch.float32}\n"
     ]
    }
   ],
   "source": [
    "model = hrnet_fpn()\n",
    "output = model(torch.randn(1, 3, 256, 256))\n",
    "\n",
    "print(count_parameters(model, human_friendly=True))\n",
    "print(describe_outputs(output))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:37:44.923060600Z",
     "start_time": "2023-05-04T20:37:44.158055300Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'total': '43.7M', 'trainable': '43.7M', 'encoder': '23.5M', 'decoder': '20.2M', 'head': '1.15K'}\n",
      "{'size': (1, 1, 256, 256), 'mean': 0.1509934365749359, 'std': 0.26733869314193726, 'dtype': torch.float32}\n"
     ]
    }
   ],
   "source": [
    "from pytorch_toolbelt.modules import UpsampleLayerType\n",
    "\n",
    "\n",
    "def resnet50d_unet():\n",
    "    encoder = encoders.TimmResnet50D(\n",
    "        pretrained=True, layers=[1, 2, 3, 4],\n",
    "    )\n",
    "    decoder = decoders.UNetDecoder(\n",
    "        input_spec=encoder.get_output_spec(),\n",
    "        out_channels=(128, 256, 512),\n",
    "        upsample_block=UpsampleLayerType.BILINEAR,\n",
    "    )\n",
    "    head = ResizeHead(\n",
    "        input_spec=decoder.get_output_spec(),\n",
    "        num_classes=1,\n",
    "        dropout_rate=0.2,\n",
    "    )\n",
    "    return GenericSegmentationModel(\n",
    "        encoder,\n",
    "        decoder,\n",
    "        head,\n",
    "    )\n",
    "\n",
    "model = resnet50d_unet()\n",
    "output = model(torch.randn(1, 3, 256, 256))\n",
    "print(count_parameters(model, human_friendly=True))\n",
    "print(describe_outputs(output))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:37:35.322656900Z",
     "start_time": "2023-05-04T20:37:34.457745Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-04T20:36:59.872277900Z",
     "start_time": "2023-05-04T20:36:59.856983600Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
